import unittest

import tensorflow as tf

from model.tf.mobilenet_v2 import MobileNetV2

# Spatial size (height x width) and channel count of the dummy input image
# the tests in this module feed to the model.
HEIGHT = 32
WIDTH = 32
DEPTH = 3


class MyTestCase(unittest.TestCase):
    """Smoke tests for the TensorFlow ``MobileNetV2`` model."""

    def setUp(self) -> None:
        """Cap the first visible GPU at 1024 MB by configuring a virtual device.

        This is best-effort. If TensorFlow rejects the configuration with a
        ``RuntimeError`` (e.g. because the GPU was already initialized), the
        error is printed and the test proceeds with the default GPU setup.
        """
        gpus = tf.config.experimental.list_physical_devices('GPU')
        if not gpus:
            # CPU-only machine: nothing to configure.
            return
        try:
            tf.config.experimental.set_virtual_device_configuration(
                gpus[0],
                [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
        except RuntimeError as e:
            # Deliberately non-fatal: report and continue.
            print(e)

    def test_mobilenet_v2(self):
        """Build ``MobileNetV2`` and check its output shape for a batch of inputs."""
        batch_size = 32
        num_classes = 200
        mobilenet_v2 = MobileNetV2(num_classes=num_classes)
        # Calling the model on a symbolic Keras input builds it, so that
        # summary() and compute_output_shape() have weights/shapes to report.
        mobilenet_v2(tf.keras.layers.Input([HEIGHT, WIDTH, DEPTH]))
        # summary() prints the table itself and returns None, so it must not
        # be wrapped in print() (that would emit an extra "None" line).
        mobilenet_v2.summary()
        output_shape = mobilenet_v2.compute_output_shape(
            [batch_size, HEIGHT, WIDTH, DEPTH])
        # Depending on the Keras version this is a TensorShape or a tuple;
        # normalize to a plain tuple for an unambiguous comparison.
        self.assertEqual(
            (batch_size, num_classes),
            tuple(tf.TensorShape(output_shape).as_list()))
